Customer churn is a problem that every company needs to monitor, especially those that depend on subscription-based revenue. Churn refers to a customer ending their relationship with a company, and it is a costly problem: lost customers mean lost sales, and acquiring a new customer is typically far more expensive than retaining an existing one. Organizations therefore need to focus on reducing churn.
The dataset used in this Keras tutorial is the IBM Watson Telco Customer Churn dataset. According to IBM, the business challenge is:
“A telecommunications company [Telco] is concerned about the number of customers leaving their landline business for cable competitors. They need to understand who is leaving. Imagine that you’re an analyst at this company and you have to find out who is leaving and why.”
We are going to use the Keras library to develop a deep learning model in Python, walking through the preprocessing steps and how to format the data for Keras.
Finally, we show how to get insights into the black-box neural network using the lime package.
import pandas as pd
import numpy as np
df = pd.read_csv("../../data/Telco-Customer-Churn.csv")
df
|   | customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| 1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.5 | No |
| 2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| 3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
| 4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 7038 | 6840-RESVB | Male | 0 | Yes | Yes | 24 | Yes | Yes | DSL | Yes | ... | Yes | Yes | Yes | Yes | One year | Yes | Mailed check | 84.80 | 1990.5 | No |
| 7039 | 2234-XADUH | Female | 0 | Yes | Yes | 72 | Yes | Yes | Fiber optic | No | ... | Yes | No | Yes | Yes | One year | Yes | Credit card (automatic) | 103.20 | 7362.9 | No |
| 7040 | 4801-JZAZL | Female | 0 | Yes | Yes | 11 | No | No phone service | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.60 | 346.45 | No |
| 7041 | 8361-LTMKD | Male | 1 | Yes | No | 4 | Yes | Yes | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 74.40 | 306.6 | Yes |
| 7042 | 3186-AJIEK | Male | 0 | No | No | 66 | Yes | No | Fiber optic | Yes | ... | Yes | Yes | Yes | Yes | Two year | Yes | Bank transfer (automatic) | 105.65 | 6844.5 | No |
7043 rows × 21 columns
# Some values in the "TotalCharges" column are " " instead of NaN. Let's remove those rows.
df = df[df["TotalCharges"] != " "]
df = df.reset_index(drop=True)
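An alternative cleanup (a sketch on a toy frame, not the code above) uses `pd.to_numeric` with `errors="coerce"`, which also catches any other non-numeric strings and converts the column to float in one step:

```python
import pandas as pd

# Hypothetical mini-frame standing in for the Telco data
sample = pd.DataFrame({"TotalCharges": ["29.85", " ", "108.15"]})

# errors="coerce" turns unparseable strings (like " ") into NaN
sample["TotalCharges"] = pd.to_numeric(sample["TotalCharges"], errors="coerce")
cleaned = sample.dropna(subset=["TotalCharges"]).reset_index(drop=True)
```

This would make the later `astype(float)` step unnecessary, since the column is already numeric.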
df
|   | customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| 1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.5 | No |
| 2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| 3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
| 4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 7027 | 6840-RESVB | Male | 0 | Yes | Yes | 24 | Yes | Yes | DSL | Yes | ... | Yes | Yes | Yes | Yes | One year | Yes | Mailed check | 84.80 | 1990.5 | No |
| 7028 | 2234-XADUH | Female | 0 | Yes | Yes | 72 | Yes | Yes | Fiber optic | No | ... | Yes | No | Yes | Yes | One year | Yes | Credit card (automatic) | 103.20 | 7362.9 | No |
| 7029 | 4801-JZAZL | Female | 0 | Yes | Yes | 11 | No | No phone service | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.60 | 346.45 | No |
| 7030 | 8361-LTMKD | Male | 1 | Yes | No | 4 | Yes | Yes | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 74.40 | 306.6 | Yes |
| 7031 | 3186-AJIEK | Male | 0 | No | No | 66 | Yes | No | Fiber optic | Yes | ... | Yes | Yes | Yes | Yes | Two year | Yes | Bank transfer (automatic) | 105.65 | 6844.5 | No |
7032 rows × 21 columns
df["TotalCharges"].dtype
dtype('O')
# Convert columns to more suitable dtypes (for convenience, and to save memory and time)
df["TotalCharges"] = df["TotalCharges"].astype(float)
df["SeniorCitizen"] = df["SeniorCitizen"].astype("category")
for col in df.columns:
    if df[col].dtype == "object":
        df[col] = df[col].astype("category")
df["tenure"] = df["tenure"].astype(int)
df
|   | customerID | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | ... | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges | Churn |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 7590-VHVEG | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 | No |
| 1 | 5575-GNVDE | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | ... | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.50 | No |
| 2 | 3668-QPYBK | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 | Yes |
| 3 | 7795-CFOCW | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | ... | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 | No |
| 4 | 9237-HQITU | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 | Yes |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 7027 | 6840-RESVB | Male | 0 | Yes | Yes | 24 | Yes | Yes | DSL | Yes | ... | Yes | Yes | Yes | Yes | One year | Yes | Mailed check | 84.80 | 1990.50 | No |
| 7028 | 2234-XADUH | Female | 0 | Yes | Yes | 72 | Yes | Yes | Fiber optic | No | ... | Yes | No | Yes | Yes | One year | Yes | Credit card (automatic) | 103.20 | 7362.90 | No |
| 7029 | 4801-JZAZL | Female | 0 | Yes | Yes | 11 | No | No phone service | DSL | Yes | ... | No | No | No | No | Month-to-month | Yes | Electronic check | 29.60 | 346.45 | No |
| 7030 | 8361-LTMKD | Male | 1 | Yes | No | 4 | Yes | Yes | Fiber optic | No | ... | No | No | No | No | Month-to-month | Yes | Mailed check | 74.40 | 306.60 | Yes |
| 7031 | 3186-AJIEK | Male | 0 | No | No | 66 | Yes | No | Fiber optic | Yes | ... | Yes | Yes | Yes | Yes | Two year | Yes | Bank transfer (automatic) | 105.65 | 6844.50 | No |
7032 rows × 21 columns
df.dtypes
customerID          category
gender              category
SeniorCitizen       category
Partner             category
Dependents          category
tenure                 int32
PhoneService        category
MultipleLines       category
InternetService     category
OnlineSecurity      category
OnlineBackup        category
DeviceProtection    category
TechSupport         category
StreamingTV         category
StreamingMovies     category
Contract            category
PaperlessBilling    category
PaymentMethod       category
MonthlyCharges       float64
TotalCharges         float64
Churn               category
dtype: object
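The `category` dtype stores each value as a small integer code plus a lookup table of unique labels, which is where the space savings come from. A quick illustration on synthetic data:

```python
import pandas as pd

s_obj = pd.Series(["Yes", "No"] * 50_000)   # object dtype: one Python string per row
s_cat = s_obj.astype("category")            # two categories + compact integer codes

obj_bytes = s_obj.memory_usage(deep=True)
cat_bytes = s_cat.memory_usage(deep=True)   # substantially smaller
```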
df = df.drop(columns = ["customerID"]) # Drop the customer ID column
df = df.dropna() # Drop any row that has at least 1 NaN
df = df.reindex(columns=['Churn'] + list(df.columns.drop('Churn'))) # Bring the churn in front
df
|   | Churn | gender | SeniorCitizen | Partner | Dependents | tenure | PhoneService | MultipleLines | InternetService | OnlineSecurity | OnlineBackup | DeviceProtection | TechSupport | StreamingTV | StreamingMovies | Contract | PaperlessBilling | PaymentMethod | MonthlyCharges | TotalCharges |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | No | Female | 0 | Yes | No | 1 | No | No phone service | DSL | No | Yes | No | No | No | No | Month-to-month | Yes | Electronic check | 29.85 | 29.85 |
| 1 | No | Male | 0 | No | No | 34 | Yes | No | DSL | Yes | No | Yes | No | No | No | One year | No | Mailed check | 56.95 | 1889.50 |
| 2 | Yes | Male | 0 | No | No | 2 | Yes | No | DSL | Yes | Yes | No | No | No | No | Month-to-month | Yes | Mailed check | 53.85 | 108.15 |
| 3 | No | Male | 0 | No | No | 45 | No | No phone service | DSL | Yes | No | Yes | Yes | No | No | One year | No | Bank transfer (automatic) | 42.30 | 1840.75 |
| 4 | Yes | Female | 0 | No | No | 2 | Yes | No | Fiber optic | No | No | No | No | No | No | Month-to-month | Yes | Electronic check | 70.70 | 151.65 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 7027 | No | Male | 0 | Yes | Yes | 24 | Yes | Yes | DSL | Yes | No | Yes | Yes | Yes | Yes | One year | Yes | Mailed check | 84.80 | 1990.50 |
| 7028 | No | Female | 0 | Yes | Yes | 72 | Yes | Yes | Fiber optic | No | Yes | Yes | No | Yes | Yes | One year | Yes | Credit card (automatic) | 103.20 | 7362.90 |
| 7029 | No | Female | 0 | Yes | Yes | 11 | No | No phone service | DSL | Yes | No | No | No | No | No | Month-to-month | Yes | Electronic check | 29.60 | 346.45 |
| 7030 | Yes | Male | 1 | Yes | No | 4 | Yes | Yes | Fiber optic | No | No | No | No | No | No | Month-to-month | Yes | Mailed check | 74.40 | 306.60 |
| 7031 | No | Male | 0 | No | No | 66 | Yes | No | Fiber optic | Yes | No | Yes | Yes | Yes | Yes | Two year | Yes | Bank transfer (automatic) | 105.65 | 6844.50 |
7032 rows × 20 columns
# Choose the categorical columns to one-hot encode; exclude Churn (the target)
categorical_columns = [col for col in df.columns if df[col].dtype.name == "category"]  # get categorical columns
categorical_columns = [col for col in categorical_columns if col != "Churn"]  # exclude the target
categorical_columns
['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']
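The same list can be built with `DataFrame.select_dtypes` (an equivalent sketch on a toy frame):

```python
import pandas as pd

# Toy frame with the same dtype layout as the Telco data
df = pd.DataFrame({"gender": pd.Categorical(["Female", "Male"]),
                   "tenure": [1, 34],
                   "Churn": pd.Categorical(["No", "No"])})

# select_dtypes filters columns by dtype; then drop the target
categorical_columns = [c for c in df.select_dtypes(include="category").columns
                       if c != "Churn"]
```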
from sklearn.model_selection import train_test_split
train_data_raw, test_data_raw = train_test_split(df, test_size=0.2, random_state = 10)
print(f"Train set has: {len(train_data_raw)} rows")
print(f"Test set has: {len(test_data_raw)} rows")
Train set has: 5625 rows
Test set has: 1407 rows
train_labels = train_data_raw["Churn"].map({"Yes": 1, "No": 0}) # Make "Yes" and "No" to 1 and 0.
test_labels = test_data_raw["Churn"].map({"Yes": 1, "No": 0})
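Churn datasets are typically imbalanced; in this dataset roughly a quarter of customers churn. It is worth confirming the class balance before training, since it affects how accuracy should be read later. A toy check of the idea (synthetic labels, not the actual split):

```python
import pandas as pd

# Hypothetical labels with a 27% positive rate
labels = pd.Series(["No"] * 73 + ["Yes"] * 27).map({"Yes": 1, "No": 0})
churn_rate = labels.mean()  # fraction of positives after the Yes/No -> 1/0 mapping
```

A model that always predicts "No" would already score around 73% accuracy on such data, which is why AUC, precision, recall, and F1 are tracked below as well.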
from sklearn.preprocessing import StandardScaler
def pre_process(df):
    # pd.cut with right=False gives left-closed bins, matching R's discretize default
    df['tenure'] = pd.cut(df['tenure'], bins=6, labels=False, right=False)  # discretize tenure into 6 categories
    df['TotalCharges'] = np.log(df['TotalCharges'])  # put TotalCharges on a log scale
    df = pd.get_dummies(data=df, columns=categorical_columns + ["tenure"], drop_first=True)  # one-hot encode
    scaler = StandardScaler()
    df = scaler.fit_transform(df.drop(columns=["Churn"]))  # standardize features (fitted separately per split)
    return df
train_data = pre_process(train_data_raw.copy())  # .copy() avoids SettingWithCopyWarning, since pre_process mutates its input
test_data = pre_process(test_data_raw.copy())
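One caveat worth flagging: `pre_process` fits a fresh `StandardScaler` on each split, so the test set is standardized with its own statistics. A leakage-free variant (a sketch with synthetic data, not the exact pipeline above) fits the scaler on the training split only and reuses it:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
X_train = rng.normal(size=(100, 3))
X_test = rng.normal(size=(20, 3))

scaler = StandardScaler().fit(X_train)   # statistics come from training data only
X_train_s = scaler.transform(X_train)
X_test_s = scaler.transform(X_test)      # same mean/std applied to the test set
```

With only a feature-wise mean and standard deviation involved, the practical difference here is small, but fitting on the training split is the safer default.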
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import initializers
import tensorflow_addons as tfa # for F1 score (TFA is end-of-life; newer Keras versions ship keras.metrics.F1Score)
C:\ProgramData\Anaconda3\lib\site-packages\tensorflow_addons\utils\tfa_eol_msg.py:23: UserWarning: TensorFlow Addons (TFA) has ended development and introduction of new features. TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024. Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). For more information see: https://github.com/tensorflow/addons/issues/2807 warnings.warn(
model = keras.Sequential(
    [
        layers.Dense(16, activation="relu", kernel_initializer="uniform", input_dim=train_data.shape[1]),
        layers.Dropout(0.1),
        layers.Dense(16, activation="relu", kernel_initializer="uniform"),
        layers.Dropout(0.1),
        layers.Dense(1, kernel_initializer="uniform", activation="sigmoid"),
    ]
)
model.compile(loss='binary_crossentropy', optimizer='adam', metrics =['accuracy', tf.keras.metrics.AUC(),
keras.metrics.Precision(), keras.metrics.Recall(), tfa.metrics.F1Score(num_classes = 1, threshold = 0.5)])
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense (Dense) (None, 16) 560
dropout (Dropout) (None, 16) 0
dense_1 (Dense) (None, 16) 272
dropout_1 (Dropout) (None, 16) 0
dense_2 (Dense) (None, 1) 17
=================================================================
Total params: 849
Trainable params: 849
Non-trainable params: 0
_________________________________________________________________
results = model.fit(x = train_data, y = train_labels, batch_size = 50, epochs = 35, validation_split = 0.30)
Epoch 1/35
79/79 [==============================] - 2s 6ms/step - loss: 0.6292 - accuracy: 0.7285 - auc: 0.5775 - precision: 0.2143 - recall: 0.0057 - f1_score: 0.0111 - val_loss: 0.5026 - val_accuracy: 0.7305 - val_auc: 0.7976 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_f1_score: 0.0000e+00
Epoch 2/35
79/79 [==============================] - 0s 3ms/step - loss: 0.4702 - accuracy: 0.7325 - auc: 0.8084 - precision: 0.0000e+00 - recall: 0.0000e+00 - f1_score: 0.0000e+00 - val_loss: 0.4589 - val_accuracy: 0.7305 - val_auc: 0.8241 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00 - val_f1_score: 0.0000e+00
...
Epoch 35/35
79/79 [==============================] - 0s 3ms/step - loss: 0.4123 - accuracy: 0.8098 - auc: 0.8561 - precision: 0.6712 - recall: 0.5660 - f1_score: 0.6141 - val_loss: 0.4242 - val_accuracy: 0.7992 - val_auc: 0.8458 - val_precision: 0.6495 - val_recall: 0.5538 - val_f1_score: 0.5979
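The validation loss flattens out after roughly epoch 25. A common guard against training past that point is a `keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)` callback passed via `callbacks=[...]` to `model.fit`. The patience rule it applies can be sketched (and tested) in plain Python:

```python
def early_stop_epoch(val_losses, patience=5):
    """Return the 0-based epoch training would stop at, or None if it never triggers.

    Mirrors the patience rule: stop once `patience` epochs pass without a new
    best (lowest) validation loss.
    """
    best, best_epoch = float("inf"), -1
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best, best_epoch = loss, epoch      # new best: reset the counter
        elif epoch - best_epoch >= patience:
            return epoch                        # no improvement for `patience` epochs
    return None
```

With `restore_best_weights=True`, Keras also rolls the model back to the weights from the best epoch rather than the last one.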
from matplotlib import pyplot as plt
%matplotlib inline
plt.plot(results.history['accuracy'], '-o') # 'o' is to show the markers, '-' is to draw the line
plt.plot(results.history['val_accuracy'], '-o')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='lower right')
<matplotlib.legend.Legend at 0x24a54283be0>
plt.plot(results.history['loss'], '-o')
plt.plot(results.history['val_loss'], '-o')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper right')
<matplotlib.legend.Legend at 0x24a56399310>
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objs as go
### Alternative: interactive plots with Plotly (needs to be rerun each time the notebook is reopened)
plot_df = pd.DataFrame(results.history)
fig = make_subplots(rows=2, cols=1, shared_xaxes=True)
fig.add_trace(go.Scatter(y=plot_df['loss'], mode='lines+markers', name='training loss'), row=1, col=1)
fig.add_trace(go.Scatter(y=plot_df['val_loss'], mode='lines+markers', name='validation loss'), row=1, col=1)
fig.add_trace(go.Scatter(y=plot_df['accuracy'], mode='lines+markers', name='training accuracy'), row=2, col=1)
fig.add_trace(go.Scatter(y=plot_df['val_accuracy'], mode='lines+markers', name='validation accuracy'), row=2, col=1)
fig.update_xaxes(title_text="epoch", row = 2, col = 1)
fig.update_yaxes(title_text="accuracy", row = 2, col = 1)
fig.update_yaxes(title_text="loss", row = 1, col = 1)
fig
# model.evaluate returns the loss followed by the metrics, in the order given to model.compile above
test_loss, test_acc, test_auc, test_precision, test_recall, f1_score = model.evaluate(test_data, test_labels)
44/44 [==============================] - 0s 2ms/step - loss: 0.4042 - accuracy: 0.8117 - auc: 0.8547 - precision: 0.6500 - recall: 0.5762 - f1_score: 0.6109
print(f"The loss on the test data is {test_loss}, the accuracy is {test_acc}")
The loss on the test data is 0.40417763590812683, the accuracy is 0.8116559982299805
print(f"The AUC is {test_auc}")
The AUC is 0.8547202944755554
print(f"Precision: {test_precision}, Recall: {test_recall}")
Precision: 0.6499999761581421, Recall: 0.5761772990226746
print(f"F1 Score: {f1_score[0]}")
F1 Score: 0.6108664274215698
predictions_prob = model.predict(test_data)
predictions = np.round(predictions_prob)
44/44 [==============================] - 0s 1ms/step
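`np.round` applies an implicit 0.5 decision threshold to the sigmoid outputs. Because the classes are imbalanced, it can be worth tuning that threshold: lowering it flags more customers as churners, trading precision for recall. A sketch of the idea on hypothetical probabilities:

```python
import numpy as np

probs = np.array([0.15, 0.44, 0.60, 0.38])   # hypothetical sigmoid outputs

preds_default = (probs >= 0.5).astype(int)   # equivalent to np.round here
preds_lower = (probs >= 0.3).astype(int)     # more positives: higher recall, lower precision
```

In a retention campaign, missing a churner (a false negative) is often costlier than contacting a loyal customer, which argues for a lower threshold than 0.5.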
# Test data and predictions
result_df = pd.DataFrame({"truth": test_data_raw["Churn"].reset_index(drop=True),
                          "estimate": pd.Series(predictions.flatten()).map({1: "Yes", 0: "No"}),
                          "class_prob": predictions_prob.flatten()})
result_df
|   | truth | estimate | class_prob |
|---|---|---|---|
| 0 | No | No | 0.002450 |
| 1 | No | No | 0.001714 |
| 2 | No | No | 0.097320 |
| 3 | No | No | 0.082025 |
| 4 | No | No | 0.059745 |
| ... | ... | ... | ... |
| 1402 | No | No | 0.063977 |
| 1403 | No | No | 0.057431 |
| 1404 | No | No | 0.004145 |
| 1405 | No | No | 0.144564 |
| 1406 | No | No | 0.150562 |
1407 rows × 3 columns
# Show 10 random samples, since the rows above happened to be all 'No'
result_df.sample(10)
|   | truth | estimate | class_prob |
|---|---|---|---|
| 635 | Yes | No | 0.156748 |
| 1344 | No | No | 0.069637 |
| 1399 | No | No | 0.040246 |
| 1340 | No | No | 0.158345 |
| 194 | Yes | No | 0.439218 |
| 172 | No | No | 0.039678 |
| 915 | No | Yes | 0.603498 |
| 594 | No | No | 0.381667 |
| 414 | No | No | 0.106315 |
| 1235 | No | No | 0.449783 |
from sklearn.metrics import confusion_matrix
conf_matrix = confusion_matrix(result_df["truth"], result_df["estimate"], normalize='pred')  # normalize over the predicted label (columns)
conf_matrix
array([[0.85924563, 0.35 ],
[0.14075437, 0.65 ]])
conf_matrix = confusion_matrix(result_df["truth"], result_df["estimate"])
conf_matrix
array([[934, 112],
[153, 208]], dtype=int64)
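Beyond the confusion matrix, `sklearn.metrics.classification_report` summarizes per-class precision, recall, and F1 in one call (a sketch on toy labels, not the notebook's actual results):

```python
from sklearn.metrics import classification_report

# Hypothetical truth/estimate labels in the same Yes/No format as result_df
y_true = ["No", "No", "Yes", "Yes", "No"]
y_pred = ["No", "Yes", "Yes", "No", "No"]

# output_dict=True returns nested dicts instead of a formatted string
report = classification_report(y_true, y_pred, output_dict=True)
recall_yes = report["Yes"]["recall"]   # per-class recall for the churn class
```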
from sklearn.metrics import ConfusionMatrixDisplay
disp = ConfusionMatrixDisplay(conf_matrix, display_labels=["No", "Yes"])
disp.plot()
<sklearn.metrics._plot.confusion_matrix.ConfusionMatrixDisplay at 0x24a56a70d60>